
Conversation

@SamuelMarks SamuelMarks commented Jun 15, 2025

Description

Background

MaxText is currently engineered around pyconfig. Pyconfig (https://pypi.org/project/pyconfig/) was last updated in 2017 and has around 20k monthly downloads.

Pydantic (https://pypi.org/project/pydantic/) is actively maintained and has hundreds of millions of monthly downloads.

Pydantic is the most widely used data validation library for Python.

https://docs.pydantic.dev

Summary

The TL;DR version is:

  • pydantic has become the de facto standard for configuration formats and for specifying inputs and outputs (e.g., for REST APIs in the FastAPI framework)
  • pydantic integrates well with Python type-checkers, so you can often find errors before runtime
  • pydantic errors are clean and clear
  • with my Python compiler, you can get clean CLIs with --help, GUIs, and generated SQL models (SQLAlchemy) as desired
  • SDK documentation becomes much cleaner (good preparation for the PyPI release + hosted doc pages)

This introduces 487 new pydantic.fields.Field definitions across 78 classes (65 inheriting from pydantic.BaseModel) to replace the old untyped, undocumented system.
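For illustration, here is a minimal sketch (pydantic v2) of the style this PR moves toward; the class and field names below are hypothetical, not the actual MaxText schema:

from pydantic import BaseModel, ConfigDict, Field, ValidationError

class TrainConfig(BaseModel):
    """Hypothetical training config; the real MaxText fields differ."""
    # Reject unknown/misspelled keys; allow a field named `model_name`.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())
    model_name: str = Field(description="Name of the model preset to load.")
    steps: int = Field(default=10, gt=0, description="Number of training steps.")
    per_device_batch_size: int = Field(default=1, gt=0)

try:
    TrainConfig(model_name="llama2-7b", steps="ten")  # wrong type on purpose
except ValidationError as err:
    print(err)  # clear, field-level error naming `steps`, raised before any training starts

A static type checker such as mypy (with the pydantic plugin) can flag the same mistake before the code even runs.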

Migration

Proposed changes to the MaxText codebase:

  1. Completely remove the pyconfig dependency
    1. maybe not immediately, so people have a chance to easily migrate their existing setups, e.g., with a new from_pyconfig_to_pydantic function (see the sketch after this list)
  2. Migrate all examples and codebase occurrences to the new pydantic types
  3. Put the pydantic types in a hierarchy that both machines and humans can understand, e.g.:
    1. one global types.py, or
    2. one types.py per config occurrence (e.g., one per module if each module has a different config)
  4. Create a new CLI that uses common CLI syntax (e.g., this can be created automatically using my Python compiler, https://github.com/offscale/cdd-python)
  5. Migrate all shell scripts and docs to use this new CLI
    1. TBD: remove the shell scripts in favour of Python SDK usage.
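A rough sketch of the migration bridge mentioned in step 1; the name from_pyconfig_to_pydantic comes from the list above, but the body, and its reuse of the illustrative TrainConfig, are hypothetical:

from typing import Any, Mapping

def from_pyconfig_to_pydantic(legacy: Mapping[str, Any]) -> TrainConfig:
    """Hypothetical adapter from a legacy pyconfig-style mapping to a typed model.

    Reuses the illustrative `TrainConfig` from the sketch above; because that
    model forbids extra keys, misspelled legacy keys fail loudly here instead
    of being silently ignored.
    """
    return TrainConfig(**dict(legacy))

# Existing setups keep their plain dicts but gain validation on the way in:
cfg = from_pyconfig_to_pydantic({"model_name": "llama2-7b", "steps": 10})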

Tests

CI and manual:

$ bash ./dependencies/scripts/docker_build_dependency_image.sh DEVICE='tpu' MODE='nightly'
$ bash ./dependencies/scripts/docker_build_dependency_image.sh DEVICE='tpu' MODE='stable'
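# Note: the "${VAR?}" expansions below make the shell exit with an error if
# the variable is unset, so missing cloud settings fail fast instead of
# yielding a half-configured run.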
$ export MODEL_NAME='llama3_1_70b_8192_synthetic' \
         PROJECT="${GOOGLE_CLOUD_PROJECT?}" \
         ZONE="${GOOGLE_CLOUD_ZONE?}" \
         CLUSTER_NAME="${GOOGLE_CLOUD_CLUSTER_NAME?}" \
         OUTPUT_DIR="${GOOGLE_CLOUD_BUCKET?}" \
         BASE_OUTPUT_DIR="${GOOGLE_CLOUD_BUCKET?}"'/output/' \
         DATASET_PATH="${GOOGLE_CLOUD_BUCKET?}"'/' \
         WORKLOAD='job_name_goes_here'

# Try running every model on the TPU VM.
# Once generated, loop through `successful_local_runs.txt` on the TPU VM
# instead of re-running every model, until new models are added or
# you're on a more powerful TPU VM.
$ for model_name in 'default' 'llama2-7b' 'llama2-13b' 'llama2-70b' 'llama3-8b' 'llama3-70b' 'llama3.1-8b' \
                    'llama3.1-70b' 'llama3.1-405b' 'llama3.3-70b' 'mistral-7b' 'mixtral-8x7b' \
                    'mixtral-8x22b' 'deepseek2-16b' 'deepseek2-236b' 'deepseek3-671b' \
                    'deepseek3-test' 'deepseek3-tiny' 'kimi-k2-1t' 'gemma-7b' 'gemma-2b' 'gemma2-2b' \
                    'gemma2-9b' 'gemma2-27b' 'gemma3-4b' 'gemma3-12b' 'gemma3-27b' 'qwen3-0.6b' \
                    'qwen3-4b' 'qwen3-4b-thinking-2507' 'qwen3-8b' 'qwen3-14b' 'qwen3-32b' \
                    'qwen3-235b-a22b' 'qwen3-30b-a3b' 'qwen3-480b-a35b' 'qwen3-next-80b-a3b' \
                    'gpt3-175b' 'gpt3-22b' 'gpt3-6b' 'gpt3-52k' 'gpt-oss-20b' 'gpt-oss-120b' \
                    'llama4-17b-16e' 'llama4-17b-128e'; do
  python3 -m MaxText.train MaxText/configs/base.yml \
      run_name="${USER}"'_'"${model_name}"'_002' \
      base_output_directory="${OUTPUT_DIR?}" \
      dataset_type='synthetic' \
      steps='10' \
      model_name="$model_name" && \
  printf '%s\n' "$model_name" >> 'successful_local_runs.txt' || \
  printf '%s\n' "$model_name" >> 'failed_local_runs.txt'
done

$ wc -l 'successful_local_runs.txt'
10 successful_local_runs.txt

$ cat 'successful_local_runs.txt'
# [… omitted in lieu of succeeding markdown list]

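# Build the training command string once (stored in $command via `printf -v`)
# so it can be handed to `xpk workload create` below.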
$ printf -v command "python3 -m MaxText.train MaxText/configs/base.yml base_output_directory='%s' dataset_path='%s' steps='%d' per_device_batch_size='%d'" \
  "${BASE_OUTPUT_DIR?}" "${DATASET_PATH?}" '100' '1'

$ xpk workload create \
      --base-docker-image 'maxtext_base_image' \
      --zone "${ZONE?}" \
      --cluster "${CLUSTER_NAME?}" \
      --workload "${WORKLOAD?}" \
      --tpu-type='v6e-256' \
      --num-slices='1' \
      --command "${command?}"

# Try running every model on the TPU cluster.
# Once generated, loop through `successful_cluster_runs.txt`
# instead of re-running every model, until new models are added or
# you're on different cluster hardware.
$ for model_name in 'default_basic_1' 'default_32' 'default_64' 'default_128' 'default_256' \
                    'default_512' 'gpt_3_175b' 'gpt_3_175b_bf16' 'llama2_7b_4096' \
                    'llama2_70b_4096' 'llama2_70b_4096_synthetic' 'llama2_70b_4096_sc' \
                    'llama2_70b_4096_sc_real_data_tfds' 'llama2_70b_4096_sc_real_data_grain' \
                    'llama2_70b_4096_sc_real_data_grain_checkpoint' 'llama2_70b_4096_rd_lr' \
                    'llama3_8b_8192' 'llama3_70b_8192' 'llama3_1_405b_8192_fsdp_dcn' \
                    'llama3_1_405b_8192_pure_fsdp_ici' 'llama3_1_8b_8192' 'llama3_1_8b_8192_bs5' \
                    'llama3_1_8b_8192_no_collective_matmul' 'llama3_1_70b_8192' 'llama3_1_70b_8192_bs2' \
                    'llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul' 'llama3_1_70b_8192_bs4' \
                    'llama3_1_70b_8192_synthetic' 'llama3_1_70b_8192_rd_grain' \
                    'llama3_1_70b_8192_synthetic_ckpt' 'llama3_1_70b_8192_rd_ckpt_grain' \
                    'llama3_1_70b_8192_pw_lr_rd' 'llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds' \
                    'llama3_1_70b_8192_synth' 'llama3_1_70b_129024' 'mistral_7b' 'mixtral_8x7b_dropless' \
                    'mixtral_8x7b_dropped' 'mixtral_8x7b_dropped_int8' 'mixtral_8x22b_dropped' \
                    'deepseek_v3_ep16' 'gemma2_9b_8192' 'gemma2_27b_8192' \
                    'gemma3_12b_32768_v6e256' 'gemma3_12b_32768_2x_v6e256' \
                    'gemma3_12b_32768_4x_v6e256' 'llama3_1_70b_131072' 'custom_moe_700b' \
                    'llama3_1_405b_8192_v5p_1024' 'deepseek_v3_ep_256_v5p_512' \
                    'llama4_scout_dropless_v5p_256' 'llama4_maverick_dropless_v5p_256' 'llama2_70b_v5p_128' \
                    'llama2_7b_v5p_128' 'gpt_3_175b_v5p_128' 'gpt_3_175b_v5p_128_sc' 'deepseek3_671b_v5p_1024' \
                    'default_16b_v5e_256' 'default_32b_v5e_256' 'default_64b_v5e_256' 'default_128b_v5e_256' \
                    'gpt_3_175b_v5e_256' 'llama2_7b_v5e_256' 'llama2_13b_v5e_256' 'llama2_70b_v5e_256' \
                    'llama3_1_8b_8192_v5e_256' 'deepseek_v3_ep_256_v5p_512_c4mlperf'; do
  python3 -m benchmarks.benchmark_runner xpk \
      --base_docker_image='maxtext_base_image' \
      --project="${PROJECT?}" \
      --zone="${ZONE?}" \
      --cluster_name="${CLUSTER_NAME?}" \
      --device_type='v6e-256' \
      --num_slices='1' \
      --base_output_directory="${OUTPUT_DIR?}" \
      --model_name="$model_name" && \
  printf '%s\n' "$model_name" >> 'successful_cluster_runs.txt' || \
  printf '%s\n' "$model_name" >> 'failed_cluster_runs.txt'
done

$ wc -l 'successful_cluster_runs.txt'
67 successful_cluster_runs.txt

$ cat 'successful_cluster_runs.txt'
# [… omitted in lieu of succeeding markdown list]

TL;DR: these models worked locally:

  • default
  • mistral-7b
  • deepseek3-tiny
  • gemma-2b
  • gemma2-2b
  • qwen3-0.6b
  • qwen3-4b
  • qwen3-4b-thinking-2507
  • gpt3-6b
  • gpt3-52k

And these worked via xpk on the cluster:

  • default_basic_1
  • default_32
  • default_64
  • default_128
  • default_256
  • default_512
  • gpt_3_175b
  • gpt_3_175b_bf16
  • llama2_7b_4096
  • llama2_70b_4096
  • llama2_70b_4096_synthetic
  • llama2_70b_4096_sc
  • llama2_70b_4096_sc_real_data_tfds
  • llama2_70b_4096_sc_real_data_grain
  • llama2_70b_4096_sc_real_data_grain_checkpoint
  • llama2_70b_4096_rd_lr
  • llama3_8b_8192
  • llama3_70b_8192
  • llama3_1_405b_8192_fsdp_dcn
  • llama3_1_405b_8192_pure_fsdp_ici
  • llama3_1_8b_8192
  • llama3_1_8b_8192_bs5
  • llama3_1_8b_8192_no_collective_matmul
  • llama3_1_70b_8192
  • llama3_1_70b_8192_bs2
  • llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul
  • llama3_1_70b_8192_bs4
  • llama3_1_70b_8192_synthetic
  • llama3_1_70b_8192_rd_grain
  • llama3_1_70b_8192_synthetic_ckpt
  • llama3_1_70b_8192_rd_ckpt_grain
  • llama3_1_70b_8192_pw_lr_rd
  • llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds
  • llama3_1_70b_8192_synth
  • llama3_1_70b_129024
  • mistral_7b
  • mixtral_8x7b_dropless
  • mixtral_8x7b_dropped
  • mixtral_8x7b_dropped_int8
  • mixtral_8x22b_dropped
  • deepseek_v3_ep16
  • gemma2_9b_8192
  • gemma2_27b_8192
  • gemma3_12b_32768_v6e256
  • gemma3_12b_32768_2x_v6e256
  • gemma3_12b_32768_4x_v6e256
  • llama3_1_70b_131072
  • custom_moe_700b
  • llama3_1_405b_8192_v5p_1024
  • deepseek_v3_ep_256_v5p_512
  • llama4_scout_dropless_v5p_256
  • llama4_maverick_dropless_v5p_256
  • llama2_70b_v5p_128
  • llama2_7b_v5p_128
  • gpt_3_175b_v5p_128
  • gpt_3_175b_v5p_128_sc
  • deepseek3_671b_v5p_1024
  • default_16b_v5e_256
  • default_32b_v5e_256
  • default_64b_v5e_256
  • default_128b_v5e_256
  • gpt_3_175b_v5e_256
  • llama2_7b_v5e_256
  • llama2_13b_v5e_256
  • llama2_70b_v5e_256
  • llama3_1_8b_8192_v5e_256
  • deepseek_v3_ep_256_v5p_512_c4mlperf

Also manually ran this to test decoding:

python3 -m MaxText.decode MaxText/configs/base.yml \
    model_name=llama2-7b \
    tokenizer_path=src/MaxText/assets/tokenizer_llama3.tiktoken \
    tokenizer_type=tiktoken \
    scan_layers=false \
    per_device_batch_size=1 \
    ici_fsdp_parallelism=1 \
    ici_autoregressive_parallelism=-1 \
    max_prefill_predict_length=128 \
    max_target_length=256 \
    prompt="I love to" \
    attention=dot_product

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@richjames0 richjames0 left a comment

LGTM pending line lengths, which make it hard to read

@SamuelMarks SamuelMarks requested a review from NuojCheng as a code owner August 1, 2025 16:23

@bvandermoon bvandermoon left a comment

What do the file name letters stand for? types_j, types_g, etc.

@SamuelMarks
Collaborator Author

> What do the file name letters stand for? types_j, types_g, etc.

@bvandermoon Oh, ignore that. Each is a different attempt (in lexicographical order). All will be removed, with a single types.py taking their place, when I rebase this PR to 1 commit.

@SamuelMarks SamuelMarks force-pushed the pydantic branch 15 times, most recently from 8f66c7a to b2b4bf7 on November 12, 2025 22:25
…configuration files in MaxText ;
[src/MaxText/pyconfig.py] New temporary wrapper to not break existing API ;
[src/MaxText/pyconfig_og.py] Move original version here ;
[src/MaxText/configs/__init__.py] Make this a module ;
[tests/pyconfig_test.py] Import from og pyconfig ;
[*requirements*.txt] Add pydantic requirement ;
[tests/configs_test.py] Test every config in the repo ;
[tests/configs_value_test.py] Test various config values